# -*- coding: utf-8 -*-
# @Time    : 2023/3/20 19:04
# @Author  : xiehou
# @File    : _utils.py
# @Software: PyCharm
from torch.utils.data import Dataset


def sample_equal_to(sample1, sample2):
    """Check that every relation of ``sample1`` also occurs in ``sample2``.

    Both samples must carry the same ``"id"`` and ``"text"``; this is enforced
    with ``assert`` statements, so a mismatch raises ``AssertionError``.
    Relations are compared on their subject, predicate, object and the
    unpacked subject/object token spans.

    The comparison is one-directional: relations that are present only in
    ``sample2`` do not cause a ``False`` result.

    :param sample1: mapping with ``"id"``, ``"text"`` and ``"relation_list"``.
    :param sample2: mapping with the same keys, used as the reference.
    :return: ``True`` when no relation of ``sample1`` is missing from
        ``sample2``, ``False`` otherwise.
    """
    assert sample1["id"] == sample2["id"]
    assert sample1["text"] == sample2["text"]

    def fingerprint(relation):
        # Flatten one relation into a single string so it can live in a set;
        # U+2E80 serves as the field separator.
        return "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(
            relation["subject"],
            relation["predicate"],
            relation["object"],
            *relation["subj_tok_span"],
            *relation["obj_tok_span"]
        )

    # Index the reference sample first, then probe it with sample1's relations.
    reference = {fingerprint(relation) for relation in sample2["relation_list"]}
    return all(
        fingerprint(relation) in reference
        for relation in sample1["relation_list"]
    )


class MyDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable, sized collection.

    The wrapped collection is exposed unchanged as ``self.data``; indexing and
    ``len()`` are delegated to it directly.
    """

    def __init__(self, data):
        # Keep a reference to the caller's collection; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of items, as reported by ``len(self.data)``."""
        size = len(self.data)
        return size

    def __getitem__(self, index):
        """Return whatever ``self.data[index]`` yields for ``index``."""
        item = self.data[index]
        return item
